import functools
import inspect
from abc import ABC, ABCMeta, abstractmethod
from typing import Dict, Tuple

import torch

from trinity.algorithm.key_mapper import ALL_MAPPERS
from trinity.utils.registry import Registry

# Registry named "policy_loss_fn"; presumably concrete PolicyLossFn subclasses are
# registered into it elsewhere and looked up by name (registration is not visible
# here — confirm against callers).
POLICY_LOSS_FN = Registry("policy_loss_fn")


class PolicyLossFnMeta(ABCMeta):
    """Metaclass for policy loss functions that handles parameter name mapping and filtering."""

    ignore_keys = {"self", "kwargs", "logprob"}  # Keys to exclude from parameter selection

    def __new__(cls, name, bases, dct, **class_kwargs):
        """
        Metaclass constructor that automatically generates parameter handling logic.

        For example with `PPOPolicyLossFn` class:

        .. code-block:: python

            class PPOPolicyLossFn(PolicyLossFn):
                ...
                def __call__(
                    self,
                    logprob: torch.Tensor,
                    old_logprob: torch.Tensor,
                    action_mask: torch.Tensor,
                    advantages: torch.Tensor,
                    **kwargs,
                ) -> Tuple[torch.Tensor, Dict]:
                    ...

        This metaclass analyzes the ``__call__`` method's parameters to:

        1. Generate ``_select_keys`` containing all named parameters that are not in
           ``ignore_keys`` and are not ``*args``/``**kwargs`` catch-alls
           (``["old_logprob", "action_mask", "advantages"]`` in the example above).
        2. Create a ``select_keys`` property that maps those parameters to
           trainer-specific names via ``self.mapper.from_trinity``.
        3. Wrap ``__call__`` so incoming keyword names are converted with
           ``self.mapper.to_trinity`` and every keyword other than ``logprob``
           and the ``_select_keys`` is dropped before the real method runs.

        The wrapper is built with ``functools.wraps``, so the wrapped method keeps
        its name, docstring, signature and ``__isabstractmethod__`` flag (an
        ``@abstractmethod`` ``__call__`` stays abstract).

        A class body that does not define ``__call__`` (for example a subclass that
        only overrides ``default_args``) is created unchanged and inherits
        ``_select_keys``, ``select_keys`` and the wrapped ``__call__`` from its bases.

        Args:
            name: Name of the class being created.
            bases: Base classes of the class being created.
            dct: Class namespace; modified in place before class creation.
            **class_kwargs: Class keyword arguments, forwarded to ``ABCMeta.__new__``.
        """
        if "__call__" not in dct:
            # Nothing to analyze here; inherit the bases' generated attributes.
            return super().__new__(cls, name, bases, dct, **class_kwargs)

        signature = inspect.signature(dct["__call__"])
        catch_all_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
        dct["_select_keys"] = [
            param.name
            for param in signature.parameters.values()
            # ``*args``/``**kwargs`` are never named inputs, whatever they are called.
            if param.name not in cls.ignore_keys and param.kind not in catch_all_kinds
        ]

        # Property to return trainer-specific parameter names
        def select_keys(self):
            """Returns parameter keys mapped to the specific training framework's naming convention."""
            return [self.mapper.from_trinity(key) for key in self._select_keys]

        # Decorator to handle parameter name conversion before calling __call__
        def decorator(func):
            # ``wraps`` copies ``__isabstractmethod__`` (via ``__dict__``) and sets
            # ``__wrapped__`` so ``inspect.signature`` still sees the real parameters.
            @functools.wraps(func)
            def wrapper(self, *args, **kwargs):
                # Filter and convert keyword names to the trinity convention.
                accepted = {"logprob", *self._select_keys}
                new_kwargs = {}
                for key, value in kwargs.items():
                    key = self.mapper.to_trinity(key)
                    if key in accepted:  # remove unused keys
                        new_kwargs[key] = value
                return func(self, *args, **new_kwargs)

            return wrapper

        # Add the property and decorated method to the class
        dct["select_keys"] = property(select_keys)
        dct["__call__"] = decorator(dct["__call__"])
        return super().__new__(cls, name, bases, dct, **class_kwargs)


class PolicyLossFn(ABC, metaclass=PolicyLossFnMeta):
    """
    Abstract base class for policy loss functions.

    This class provides the interface for implementing different policy gradient loss functions
    while handling parameter name mapping between different training frameworks.

    Subclasses implement ``__call__`` and ``default_args``. The parameters a subclass's
    ``__call__`` declares besides ``self``, ``logprob`` and ``**kwargs`` become its
    ``_select_keys`` (computed by ``PolicyLossFnMeta``), and ``select_keys`` exposes them
    under the backend's naming convention.
    """

    def __init__(self, backend: str = "verl"):
        """
        Initialize the policy loss function.

        Args:
            backend: The training framework/backend to use (e.g., "verl"). It is used as the
                key into ``ALL_MAPPERS`` to select the key mapper stored in ``self.mapper``;
                the lookup fails if no mapper is registered under that name.
        """
        self.backend = backend
        # Mapper used to translate parameter names to/from the backend's convention.
        self.mapper = ALL_MAPPERS[self.backend]

    @abstractmethod
    def __call__(
        self,
        logprob: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict]:
        """
        Calculate the policy loss.

        Args:
            logprob (`torch.Tensor`): The log probability of the actions under the current
                policy model.

        Kwargs (optional):
            old_logprob (`torch.Tensor`): The log probability of the actions under the old
                policy, i.e. the policy before the current update step.
            action_mask (`torch.Tensor`): The action mask.
            advantages (`torch.Tensor`): The advantages.
            kwargs (`Dict`): Other step-level parameters. Note that the ``PolicyLossFnMeta``
                wrapper forwards only ``logprob`` and the parameters the concrete ``__call__``
                declares by name; any other keyword is dropped before the implementation runs.

        Returns:
            `torch.Tensor`: Policy loss
            `Dict`: The metrics for logging.
        """

    @classmethod
    @abstractmethod
    def default_args(cls) -> Dict:
        """
        Get default initialization arguments for this loss function.

        Returns:
            `Dict`: The default init arguments for the policy loss function.
        """
